import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm


# ===== Normalization (helper functions)

def normalize_to_neg_one_to_one(img):
    # [0.0, 1.0] -> [-1.0, 1.0]
    return img * 2 - 1


def unnormalize_to_zero_to_one(t):
    # [-1.0, 1.0] -> [0.0, 1.0]
    return (t + 1) * 0.5
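
# Round-trip check (illustrative): the two helpers are exact inverses, e.g.
#   t = torch.rand(2, 3, 8, 8)
#   assert torch.allclose(unnormalize_to_zero_to_one(normalize_to_neg_one_to_one(t)), t)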


class EDM(nn.Module):
    def __init__(self, nn_model,
                 sigma_data, p_mean, p_std, sigma_min, sigma_max, rho,
                 device):
        '''
            EDM proposed by "Elucidating the Design Space of Diffusion-Based Generative Models".
            Args:
                `nn_model`: A network (e.g. a UNet) that performs a same-shape mapping.
                `device`: The device that tensors are placed on.
            Training parameters:
                `sigma_data`, `p_mean`, `p_std`
            Sampling parameters:
                `sigma_min`, `sigma_max`, `rho`
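            Reference values (CIFAR-10 defaults in the EDM paper):
                `sigma_data=0.5`, `p_mean=-1.2`, `p_std=1.2`,
                `sigma_min=0.002`, `sigma_max=80`, `rho=7`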
        '''
        super().__init__()
        self.nn_model = nn_model.to(device)
        params = sum(p.numel() for p in nn_model.parameters() if p.requires_grad) / 1e6
        print(f"nn model # params: {params:.1f}M")

        self.device = device

        def number_to_torch_device(value):
            return torch.tensor(value, device=device)

        self.sigma_data = number_to_torch_device(sigma_data)
        self.p_mean     = number_to_torch_device(p_mean)
        self.p_std      = number_to_torch_device(p_std)
        self.sigma_min  = number_to_torch_device(sigma_min)
        self.sigma_max  = number_to_torch_device(sigma_max)
        self.rho        = number_to_torch_device(rho)

    def D_x(self, x_noised, sigma):
        '''
            Denoising with network preconditioning.
            Args:
                `x_noised`: The perturbed image tensor.
                `sigma`: The noise level (standard deviation) tensor.
            Returns:
                The estimated denoised image tensor `x`.
        '''
        # Preconditioning
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_noise = sigma.log() / 4
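        # (Per the EDM paper: c_in scales the input to unit variance, c_out scales
        # the effective training target to unit variance, c_skip minimizes error
        # amplification, and c_noise = ln(sigma) / 4 conditions on the noise level.)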

        # Denoising
        F_x = self.nn_model(c_in * x_noised, c_noise.flatten())
        return c_skip * x_noised + c_out * F_x

    def forward(self, x):
        '''
            Training with weighted denoising loss.
            Args:
                `x`: The clean image tensor with values in `[0, 1]`.
            Returns:
                The weighted MSE loss tensor.
        '''
        x = normalize_to_neg_one_to_one(x)

        # Perturbation: draw per-sample noise levels with log(sigma) ~ N(p_mean, p_std^2)
        rnd_normal = torch.randn((x.shape[0], 1, 1, 1), device=self.device)
        sigma = (rnd_normal * self.p_std + self.p_mean).exp()
        noise = torch.randn_like(x)
        x_noised = x + noise * sigma

        # Weighted denoising loss: weight = 1 / c_out(sigma)^2, which equalizes
        # the effective loss magnitude across noise levels
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        loss_per_pixel = weight * ((x - self.D_x(x_noised, sigma)) ** 2)
        return loss_per_pixel.mean()


    @torch.no_grad()  # sampling does not need an autograd graph
    def edm_sample(self, n_sample, size, notqdm=False, num_steps=18,
                   S_churn=0, S_min=0, S_max=float('inf'), S_noise=1):
        '''
            Sampling with stochastic sampler.
            Args:
                `n_sample`: The batch size.
                `size`: The image shape tuple (e.g. `(3, 32, 32)`).
                `num_steps`: The number of time steps for discretization. Actual NFE is `2 * num_steps - 1`.
                `S_churn`: Controls the amount of stochasticity. Set `S_churn=0` for deterministic (Heun) sampling.
            Returns:
                The sampled image tensor.
        '''
        sigma_min, sigma_max, rho = self.sigma_min, self.sigma_max, self.rho

        # Time steps: rho-spaced noise levels descending from sigma_max to sigma_min
        times = torch.arange(num_steps, device=self.device)
        times = (sigma_max ** (1 / rho) + times / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        times = torch.cat([times, torch.zeros_like(times[:1])]) # t_N = 0
        time_pairs = list(zip(times[:-1], times[1:]))

        # Initialize from the prior: x ~ N(0, t_0^2 I), with t_0 = sigma_max
        x_next = torch.randn(n_sample, *size, device=self.device) * times[0]
        for i, (t_cur, t_next) in enumerate(tqdm(time_pairs, disable=notqdm)): # 0, ..., N-1
            x_cur = x_next

            # Temporarily increase the noise level: lift t_cur to t_hat = (1 + gamma) * t_cur.
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

            # Euler step.
            denoised = self.D_x(x_hat, t_hat)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

            # Apply 2nd order (Heun) correction; skipped at the last step where t_next = 0.
            if i < num_steps - 1:
                denoised = self.D_x(x_next, t_next)
                d_prime = (x_next - denoised) / t_next
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        x_next = unnormalize_to_zero_to_one(x_next)
        return x_next

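
# ===== Usage sketch (illustrative)
# A minimal smoke test, not part of the module proper: `_ToyModel` is a
# hypothetical stand-in for a real UNet with the expected (x, t) -> x
# signature, and the hyperparameters are the CIFAR-10 defaults from the
# EDM paper.
if __name__ == "__main__":
    class _ToyModel(nn.Module):
        # Same-shape conv map, conditioned on the noise level via a bias.
        def __init__(self, channels=3):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)

        def forward(self, x, t):
            # `t` arrives as c_noise with shape (batch,) or (1,); broadcast it.
            return self.conv(x) + t.view(-1, 1, 1, 1)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    edm = EDM(_ToyModel(),
              sigma_data=0.5, p_mean=-1.2, p_std=1.2,  # training defaults
              sigma_min=0.002, sigma_max=80, rho=7,    # sampling defaults
              device=device)

    # One training step on random "images" in [0, 1].
    x = torch.rand(4, 3, 32, 32, device=device)
    loss = edm(x)
    loss.backward()
    print(f"loss: {loss.item():.4f}")

    # Deterministic (Heun) sampling; NFE = 2 * num_steps - 1 = 35.
    samples = edm.edm_sample(n_sample=4, size=(3, 32, 32), num_steps=18)
    print(samples.shape)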